# -*- coding:utf-8 -*-

import codecs
import os
import pickle
from collections import Counter

import numpy as np

import config.glob.config as global_cfg
from config.glob.global_pool import global_pool
from utils.common_utils import Singleton
from utils.log_utils import log_debug

"""
构建词汇表
"""


class VocabList(Singleton):
    """
    Vocabulary singleton; characters are later represented one-hot by index.

    Attributes:
        vocab: list of characters sorted by descending frequency.
        word_to_int_table: dict mapping character -> index.
        int_to_word_table: dict mapping index -> character.

    Index ``len(self.vocab)`` is reserved for unknown characters (``<unk>``),
    which is why :attr:`vocab_size` is ``len(self.vocab) + 1``.
    """
    def __init__(self, text=None, max_vocab=5000, filename=None):
        """
        Build the vocabulary from ``text`` or load a pickled one from disk.

        :param text: iterable of characters used to build the vocabulary
                     (required when ``filename`` is None)
        :param max_vocab: maximum number of characters kept after frequency sort
        :param filename: path to a previously pickled vocabulary list
        :raises ValueError: if neither ``text`` nor ``filename`` is given
        """
        if filename is not None:  # load from file
            with open(filename, 'rb') as f:
                self.vocab = pickle.load(f)
        else:  # build from raw text
            if text is None:
                # fail fast with a clear message instead of a TypeError
                # deep inside build_vocab when iterating None
                raise ValueError('either `text` or `filename` must be provided')
            self.vocab = None
            self.build_vocab(text, max_vocab)
        self.word_to_int_table = {c: i for i, c in enumerate(self.vocab)}  # dict(char, idx)
        self.int_to_word_table = dict(enumerate(self.vocab))  # dict(idx, char)

    def build_vocab(self, text, max_vocab):
        """
        Build the vocabulary: count character frequencies, sort by frequency
        (descending), and keep at most ``max_vocab`` characters.

        :param text: iterable of characters
        :param max_vocab: maximum vocabulary size
        :return: None; result is stored in ``self.vocab``
        """
        counts = Counter(text)  # one pass replaces the dedup + counting loops
        log_debug('\033[32mvocab num:{}\033[0m'.format(len(counts)))

        # most_common sorts by frequency (descending) and truncates in one
        # step; ties keep first-seen order, which is deterministic (the old
        # set-based code left tie order up to hash iteration order)
        self.vocab = [word for word, _ in counts.most_common(max_vocab)]

    @property
    def vocab_size(self):
        """
        Vocabulary size including the reserved <unk> slot.

        :return: len(vocab) + 1
        """
        return len(self.vocab) + 1

    def word_to_int(self, word):
        """
        Map a character to its index; unknown characters map to len(vocab).

        :param word: a single character
        :return: integer index
        """
        if word in self.word_to_int_table:
            return self.word_to_int_table[word]
        else:
            return len(self.vocab)

    def int_to_word(self, index):
        """
        Map an index back to its character; len(vocab) maps to '<unk>'.

        :param index: integer index
        :return: character, or '<unk>' for the reserved index
        :raises Exception: if index is out of range
        """
        if index == len(self.vocab):
            return '<unk>'
        elif index < len(self.vocab):
            return self.int_to_word_table[index]
        else:
            raise Exception('Unknown index!')

    def text_to_arr(self, text):
        """
        Convert a string to an array of indices.

        :param text: iterable of characters
        :return: numpy int array of indices
        """
        return np.array([self.word_to_int(word) for word in text])

    def arr_to_text(self, arr):
        """
        Convert an array of indices back to a string.

        :param arr: iterable of integer indices
        :return: decoded string
        """
        return "".join(self.int_to_word(index) for index in arr)

    def save_to_file(self, filename):
        """
        Pickle the vocabulary list to a file (same format __init__ loads).

        :param filename: destination path
        :return: None
        """
        with open(filename, 'wb') as f:
            pickle.dump(self.vocab, f)


def build_embed():
    """
    Load (or build and cache) the character vocabulary for embeddings.

    Reads the raw corpus, then either loads the pickled vocabulary if it
    already exists or builds it from the corpus and pickles it for next time.

    :return: a VocabList instance
    """
    # absolute path of the embedding corpus; os.path.join is used instead of
    # hard-coded '\\' separators so this also works on non-Windows systems
    # (on Windows the resulting string is identical)
    dataset_path = os.path.join(
        global_cfg.DATASET_ROOT, global_pool.config.embed.dataset, 'poetry.txt'
    )
    log_debug('\033[32membedding dataset path:{}\033[0m'.format(dataset_path))
    # path of the pickled vocabulary model
    embed_path = os.path.join(
        global_cfg.DATASET_ROOT,
        global_pool.config.embed.dataset,
        global_pool.config.embed.model
    )
    log_debug('\033[32membedding path:{}\033[0m'.format(embed_path))

    with codecs.open(dataset_path, encoding='utf-8') as f:
        text = f.read()
    # reuse the cached vocabulary if present, otherwise rebuild and cache it
    if os.path.exists(embed_path):
        embedding = VocabList(filename=embed_path)
    else:
        embedding = VocabList(text=text, max_vocab=5000)
        embedding.save_to_file(embed_path)
    return embedding
